{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7717416e",
   "metadata": {},
   "source": [
    "***本项目是用来演示使用中文菜谱数据调优ChatGLM2-6B***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc142f32-4a99-47a6-b4af-26805dc73eb8",
   "metadata": {},
   "source": [
    "### 1、加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3d72b7f8",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.02333235740661621,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 7,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fea2a875a75f41c9bfb7e64b1cf9b442",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "from transformers import  AutoModel,AutoTokenizer, BitsAndBytesConfig\n",
    "import torch\n",
    "model_name = \"./models/chatglm2-6b\" #  #或者远程 “THUDM/chatglm2-6b”\n",
    "# .cache/huggingface/hub/models--THUDM--chatglm2-6b/snapshots/4e38bef4c028beafc8fb1837462f74c02e68fcc2/\n",
    "\n",
    "bnb_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_compute_dtype=torch.float16,\n",
    "    bnb_4bit_use_double_quant=True, #QLoRA 设计的 Double Quantization\n",
    "    bnb_4bit_quant_type='nf4', # QLoRA 设计的 Normal Float 4 量化数据类型\n",
    "    llm_int8_threshold=6.0,\n",
    "    llm_int8_has_fp16_weight=False,\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "model = AutoModel.from_pretrained(model_name,\n",
    "                                  quantization_config=bnb_config,\n",
    "                                  trust_remote_code=True) #.half().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "da72abc9-903a-4273-9e97-2c87772ea429",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用 Markdown 格式打印模型输出\n",
    "from IPython.display import display, Markdown, clear_output\n",
    "\n",
    "def display_answer(query, history=[]):\n",
    "    for response, history in model.stream_chat(\n",
    "            tokenizer, query, history=history):\n",
    "        clear_output(wait=True)\n",
    "        display(Markdown(response))\n",
    "    return history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3f3d97e7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "我是一个名为 ChatGLM2-6B 的人工智能助手，是基于清华大学 KEG 实验室和智谱 AI 公司于 2023 年共同训练的语言模型开发的。我的任务是针对用户的问题和要求提供适当的答复和支持。"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 测试当前模型的能力\n",
    "history = display_answer(\"你叫什么\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37236f5a",
   "metadata": {},
   "source": [
    "我们以自己收集的中文菜谱数据为例，介绍如何让大语言模型ChatGLM拥有新的技能。\n",
    "通常来讲调优模型有以下几种方式：\n",
    "1. 重新在新的数据集上训练模型，让模型具有相关的能力；\n",
    "2. 使用预训练（pretrained）的模型为基础，在新的数据集上fine-tuning，使得模型具有新的技能；\n",
    "3. 使用参数有效的调优策略（Parameter-Efficient Fine-Tuning，peft）给预训练模型打补丁，以补丁的形式让模型具备新的技能。\n",
    "\n",
    "通常来说，对于1和2适用于小模型，一般是亿级以下参数量的模型。对于ChatGLM2-6b模型，其参数量为60亿，采用1和2的方式调优模型，一般会导致GPU的显存不够，造成OOM错误。因此，本示例采用3来调优ChatGLM2-6b模型。"
   ]
  },
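  {
   "cell_type": "markdown",
   "id": "sketch-memory-estimate",
   "metadata": {},
   "source": [
    "The memory argument above can be made concrete with a back-of-the-envelope estimate. This is a rough sketch, not a measurement: the parameter count is rounded from the figure this notebook prints later, and the bytes-per-parameter values (2 for fp16 weights, 0.5 for 4-bit weights, about 16 for full fine-tuning with AdamW in mixed precision) are common rules of thumb that ignore activations and framework overhead. The helper `estimate_gib` is hypothetical.\n",
    "\n",
    "```python\n",
    "# Rough memory estimate; every bytes-per-parameter figure is an assumption.\n",
    "N_PARAMS = 6.2e9  # assumed: about 6.2 billion parameters\n",
    "\n",
    "def estimate_gib(n_params, bytes_per_param):\n",
    "    '''Return the memory needed for n_params parameters, in GiB.'''\n",
    "    return n_params * bytes_per_param / 2**30\n",
    "\n",
    "fp16_weights = estimate_gib(N_PARAMS, 2)     # 16-bit weights only\n",
    "nf4_weights = estimate_gib(N_PARAMS, 0.5)    # 4-bit weights only\n",
    "full_finetune = estimate_gib(N_PARAMS, 16)   # weights + gradients + optimizer states (assumed)\n",
    "\n",
    "print(f'fp16 weights:     {fp16_weights:.1f} GiB')\n",
    "print(f'4-bit weights:    {nf4_weights:.1f} GiB')\n",
    "print(f'full fine-tuning: {full_finetune:.1f} GiB')\n",
    "```\n",
    "\n",
    "Under these assumptions, full fine-tuning needs far more memory than a single consumer GPU offers, while the 4-bit weights fit comfortably, which is why the rest of the notebook combines quantization with a small trainable adapter."
   ]
  },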
  {
   "cell_type": "markdown",
   "id": "3bed2b2b",
   "metadata": {},
   "source": [
    "### 2、数据准备"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53a832a6",
   "metadata": {},
   "source": [
    "#### 2.1 处理数据，提取必要的信息"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f184db15",
   "metadata": {},
   "source": [
    "数据总共包括136048条菜谱数据，数据来自网友在https://www.meishichina.com/ 抓取的结果。\n",
    "![](https://enpei-md.oss-cn-hangzhou.aliyuncs.com/img202307241450198.png?x-oss-process=style/wp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32613b75-79c8-4e4f-8afa-4681a067463a",
   "metadata": {},
   "source": [
    "数据格式如下：\n",
    "```json\n",
    "{\n",
    "  \"id\": \"1\",\n",
    "  \"title\": \"红烧鸡翅\",\n",
    "  \"intro\": \"\",\n",
    "  \"image\": \"http://i8.meishichina.com/attachment/recipe/200910/200910120907019.jpg@!p800\",\n",
    "  \"steps\": [\n",
    "    {\n",
    "      \"index\": 1,\n",
    "      \"image\": \"\",\n",
    "      \"content\": \"鸡翅洗净抹干水分，加入腌料拌匀，腌制1小时\"\n",
    "    },\n",
    "    {\n",
    "      \"index\": 2,\n",
    "      \"image\": \"\",\n",
    "      \"content\": \"牛腩切3厘米左右的块；土豆一半切1厘米小块另一半切3厘米左右的块；洋葱切碎，至于是切块还是切条切丝，大家随意啊；胡萝卜切的小一点，会入味，甜甜的。\"\n",
    "    },\n",
    "    {\n",
    "      \"index\": 3,\n",
    "      \"image\": \"\",\n",
    "      \"content\": \"很疑惑\"\n",
    "    }\n",
    "  ],\n",
    "  \"ingredients\": {\n",
    "    \"鸡翅中\": \"8个\",\n",
    "    \"姜（腌料）\": \"2片\",\n",
    "    \"葱（腌料）\": \"2根\",\n",
    "    \"盐（腌料）\": \"4克\",\n",
    "    \"料酒（调料A）\": \"半汤勺\",\n",
    "    \"酱油（调料A）\": \"1汤勺\",\n",
    "    \"胡椒粉（调料A）\": \"少许\",\n",
    "    \"蚝油（调料B）\": \"2汤勺\",\n",
    "    \"糖（调料B）\": \"1茶勺\",\n",
    "    \"麻油（调料B）\": \"少许\"\n",
    "  },\n",
    "  \"tags\": [],\n",
    "  \"notice\": \"特点：色泽酱红，鲜香酥嫩。\\n小提示：挑选鸡翅时，为了受热均匀，最好全部选用鸡翅中段。\",\n",
    "  \"level\": \"普通\",\n",
    "  \"craft\": \"烧\",\n",
    "  \"duration\": \"一小时\",\n",
    "  \"flavor\": \"原味\"\n",
    "}\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bba421b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import tqdm\n",
    "# 处理食谱数据\n",
    "def process_mstx(data_file):\n",
    "    # 打开JSON文件\n",
    "    with open(data_file, 'r', encoding='utf-8-sig') as fd:\n",
    "        meishi_json = json.load(fd)\n",
    "    \n",
    "    print(f\"总共有{len(meishi_json)}个菜品\")\n",
    "    \n",
    "    data = []\n",
    "    # 遍历数据，拼接字符\n",
    "    for food in tqdm.tqdm(meishi_json):\n",
    "        # 食品名称\n",
    "        food_name = food[\"title\"]\n",
    "        # 食材明细\n",
    "        ingredient = \"\"\n",
    "        for k,v in food[\"ingredients\"].items():\n",
    "            ingredient += f\"{k} : {v} \\n\"\n",
    "        # 制作步骤\n",
    "        step = \"\"\n",
    "        for st in food[\"steps\"]:\n",
    "            step += f\"第{st['index']}步：{st['content']}\"\n",
    "            \n",
    "        #制作方法\n",
    "        craft = \"\" if food['craft'] is None else food['craft']\n",
    "        duration = \"\" if food['duration'] is None else food['duration']\n",
    "        method = craft + duration\n",
    "        \n",
    "        # 构建数据\n",
    "        data.append({\n",
    "            \"id\": food['id'],\n",
    "            \"菜品名称\": food_name,\n",
    "            \"食材明细\": ingredient,\n",
    "            \"制作步骤\": step,\n",
    "            \"制作方法\": method,\n",
    "        })\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1f62a4ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "总共有136048个菜品\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 136048/136048 [00:00<00:00, 189588.18it/s]\n"
     ]
    }
   ],
   "source": [
    "data_file = \"./mstx-中文菜谱.json\"\n",
    "meishi_data = process_mstx(data_file)\n",
    "\n",
    "# 保存处理后的数据\n",
    "processed_data_file = 'mstx-中文菜谱-processed.json'\n",
    "with open(processed_data_file, 'w', encoding='utf-8') as fd:\n",
    "        fd.write(json.dumps(meishi_data, indent=4, ensure_ascii=False))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "028107f7-8479-4c90-a6e8-14ad9fa52551",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'id': '1',\n",
       " '菜品名称': '红烧鸡翅',\n",
       " '食材明细': '鸡翅中 : 8个 \\n姜（腌料） : 2片 \\n葱（腌料） : 2根 \\n盐（腌料） : 4克 \\n料酒（调料A） : 半汤勺 \\n酱油（调料A） : 1汤勺 \\n胡椒粉（调料A） : 少许 \\n蚝油（调料B） : 2汤勺 \\n糖（调料B） : 1茶勺 \\n麻油（调料B） : 少许 \\n',\n",
       " '制作步骤': '第1步：鸡翅洗净抹干水分，加入腌料拌匀，腌制1小时第2步：牛腩切3厘米左右的块；土豆一半切1厘米小块另一半切3厘米左右的块；洋葱切碎，至于是切块还是切条切丝，大家随意啊；胡萝卜切的小一点，会入味，甜甜的。第3步：很疑惑',\n",
       " '制作方法': '烧一小时'}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 打印一个示例的完整信息\n",
    "meishi_data[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c82f111d-a0bc-419b-838f-c2b0062d31f5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "红烧鸡翅\n",
      "萝卜丝鲫鱼汤\n",
      "沙茶牛肉\n",
      "五彩鳝丝\n",
      "香菇鲜肉盏\n",
      "碧绿虾仁\n",
      "剁椒蒸鱼头\n",
      "黑椒牛柳\n",
      "蜜橘鸡丁\n",
      "土家一罐香\n"
     ]
    }
   ],
   "source": [
    "# 打印10个示例\n",
    "for i in range(10):\n",
    "    print(meishi_data[i]['菜品名称'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2bddab02-f5cb-472b-9b7f-b9a00010cf1c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "萝卜丝鲫鱼汤是一道常见的中式汤品,通常使用萝卜、鲫鱼等食材制作而成。以下是一种简单的萝卜丝鲫鱼汤的食谱:\n",
       "\n",
       "所需食材:\n",
       "\n",
       "- 鲫鱼(1条约100克)\n",
       "- 萝卜(2根)\n",
       "- 葱姜适量\n",
       "- 食用油\n",
       "- 盐、胡椒粉、料酒各适量\n",
       "\n",
       "步骤:\n",
       "\n",
       "1.将鲫鱼去鳞、去骨,切成丝状;萝卜去皮,切成细丝。\n",
       "\n",
       "2.锅中倒入适量食用油,加入葱姜爆香。\n",
       "\n",
       "3.将鲫鱼丝放入锅中翻炒,加入适量盐和胡椒粉,料酒,翻炒均匀。\n",
       "\n",
       "4.倒入适量水,加入萝卜丝,煮沸。\n",
       "\n",
       "5.转小火,盖上锅盖,煮约10分钟,直到鱼和汤变稠。\n",
       "\n",
       "6.最后,根据口味加盐或糖调味即可。\n",
       "\n",
       "萝卜丝鲫鱼汤汤汁鲜美,鱼肉细腻,适合营养丰富,是一道美味的家常汤品。"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 测试当前模型的能力\n",
    "history = []\n",
    "history = display_answer(\"萝卜丝鲫鱼汤\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9a21913",
   "metadata": {},
   "source": [
    "#### 2.2 构建训练数据的prompt\n",
    "\n",
    "在本示例中，我们构建2种类型的prompt，即：\n",
    "1. 给定菜名，生成制作方法+食材+步骤\n",
    "2. 给出制作方法+食材+步骤，生成菜名；\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "92b32be5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "import datasets\n",
    "import pandas as pd\n",
    "random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "439bca7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_foodname_methods_prompt(data_file, prompt = \"\"):\n",
    "    '''\n",
    "    给菜名，生成制作方法+步骤+食材\n",
    "    '''\n",
    "    # 读取文件\n",
    "    with open(data_file, \"r\", encoding='utf8') as f:\n",
    "        caipu_json = json.load(f)\n",
    "    print(f'总共有{len(caipu_json)}个菜品')\n",
    "\n",
    "    data = []\n",
    "    for caiming in caipu_json:\n",
    "        # print(caiming['id'])\n",
    "        # 菜品名称\n",
    "        food_name = caiming[\"菜品名称\"]\n",
    "\n",
    "        # 食材明细\n",
    "        ingredient = caiming[\"食材明细\"]\n",
    "\n",
    "        # 制作步骤\n",
    "        step = caiming[\"制作步骤\"]\n",
    "\n",
    "        # 制作方法\n",
    "        method = caiming[\"制作方法\"]\n",
    "\n",
    "        # 构建prompt\n",
    "        prompt_item = {'prompt' : prompt + food_name, 'response': '\\n' + '食材明细: \\n' + ingredient + '\\n' + \"制作步骤: \\n\" +  step + '\\n' +  \"制作方法: \\n\" + method + '\\n'}\n",
    "        data.append(prompt_item)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d5a9f4ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_methods_foodname_prompt(data_file, prompt=\"\"):\n",
    "    '''\n",
    "    给食材+制作方法+步骤，生成菜名\n",
    "    '''\n",
    "    with open(data_file, \"r\", encoding='utf8') as f:\n",
    "        caipu_json = json.load(f)\n",
    "    print(f'总共有{len(caipu_json)}个菜品')\n",
    "\n",
    "    data = []\n",
    "    for caiming in caipu_json:\n",
    "        # print(caiming['id'])\n",
    "        # 菜品名称\n",
    "        food_name = caiming[\"菜品名称\"]\n",
    "\n",
    "        # 食材明细\n",
    "        ingredient = caiming[\"食材明细\"]\n",
    "\n",
    "        # 制作步骤\n",
    "        step = caiming[\"制作步骤\"]\n",
    "\n",
    "        # 制作方法\n",
    "        method = caiming[\"制作方法\"]\n",
    "\n",
    "        # 构建prompt\n",
    "        prompt_item = {'prompt': prompt + '\\n' + '食材明细: \\n' + ingredient + '\\n' + \"制作步骤: \\n\" + step + '\\n' + \"制作方法: \\n\" + method + '\\n', 'response': \"以上步骤是菜品 (\" + food_name + \") 的制作方法 \\n\"}\n",
    "        data.append(prompt_item)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c8cd9cf2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_and_concate(data_file):\n",
    "    '''\n",
    "    划分训练集和测试集，按照各自数据的8：2划分\n",
    "    '''\n",
    "    # 根据菜名生成细节\n",
    "    food_data = build_foodname_methods_prompt(data_file)\n",
    "    print(f'一共产生数据: {len(food_data)} 条')\n",
    "    # 转为pandas格式\n",
    "    food_data = pd.DataFrame(food_data)\n",
    "    # 转为torch dataset\n",
    "    food_data_ds = datasets.Dataset.from_pandas(food_data).train_test_split(test_size=0.2, shuffle=True, seed=42)\n",
    "    \n",
    "    # 根据细节生成菜名\n",
    "    method_data = build_methods_foodname_prompt(data_file)\n",
    "    print(f'一共产生数据: {len(method_data)} 条')\n",
    "    method_data = pd.DataFrame(method_data)\n",
    "    method_data_ds = datasets.Dataset.from_pandas(method_data).train_test_split(test_size=0.2, shuffle=True, seed=42)\n",
    "    \n",
    "    \n",
    "    # 拼接数据集\n",
    "    train_data = pd.concat([food_data_ds['train'].to_pandas(),method_data_ds['train'].to_pandas()])\n",
    "    test_data = pd.concat([food_data_ds['test'].to_pandas(), method_data_ds['test'].to_pandas()])\n",
    "    \n",
    "    # 返回数据\n",
    "    return train_data, test_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d8d78d8e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "总共有136048个菜品\n",
      "一共产生数据: 136048 条\n",
      "总共有136048个菜品\n",
      "一共产生数据: 136048 条\n",
      "(217676, 2)\n",
      "(54420, 2)\n",
      "                   prompt                                           response\n",
      "0                   辣拌脆黄瓜  \\n食材明细: \\n黄瓜 : 320克 \\n生姜 : 1 块 \\n独蒜 : 1 个 \\n干红...\n",
      "1               茄汁肉末鸡蛋葱花饼  \\n食材明细: \\n肉末 : 适量 \\n白面 : 适量 \\n鸡蛋 : 适量 \\n洋葱 : 适...\n",
      "2                 香蕉牛奶西米捞  \\n食材明细: \\n西米 : 适量 \\n牛奶 : 250ml \\n香蕉 : 一根 \\n枸杞 ...\n",
      "3                   酥粒桑葚卷  \\n食材明细: \\n高筋粉 : 140克 \\n水 : 80克 \\n细砂糖 : 20克 \\n黄...\n",
      "4                  绿豆糕（一）  \\n食材明细: \\n绿豆面 : 200克 \\n蜂蜜水 : 120ml \\n\\n制作步骤: \\...\n",
      "5                  大白菜烩千张  \\n食材明细: \\n大白菜 : 200g \\n千张 : 1张 \\n油 : 适量 \\n姜丝 :...\n",
      "6                 咖喱香茅煮花蛤  \\n食材明细: \\n花蛤 : 300g \\n洋葱 : 半个 \\n香茅 : 2根 \\n咖喱 :...\n",
      "7                 什锦鱼丸海带汤  \\n食材明细: \\n鱼丸 : 150g \\n海带 : 200g \\n葱花 : 适量 \\n盐 ...\n",
      "8  每个女孩心中都有的小资情结——橘子酱夹心饼干  \\n食材明细: \\n低筋面粉 : 70克 \\n黄油 : 100克 \\n糖粉 : 80克 \\n...\n",
      "9        【低卡系列之一】芝麻核桃胚芽面包  \\n食材明细: \\n面包粉 : 500克 \\n芝麻核桃粉 : 60克 \\n小麦胚芽 : 50...\n"
     ]
    }
   ],
   "source": [
    "# 将数据按照相应的prompt构建，并划分训练数据和测试数据\n",
    "processed_data_file = 'mstx-中文菜谱-processed.json'\n",
    "train_data, test_data = split_and_concate(processed_data_file)\n",
    "# 打印大小\n",
    "print(train_data.shape)\n",
    "print(test_data.shape)\n",
    "print(train_data.head(10))\n",
    "\n",
    "# 将pandas格式数据序列化保存\n",
    "train_data.to_parquet('train_data.parquet')\n",
    "test_data.to_parquet('test_data.parquet')"
   ]
  },
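  {
   "cell_type": "markdown",
   "id": "sketch-split-sizes",
   "metadata": {},
   "source": [
    "The row counts printed above can be reproduced with a short calculation. This sketch assumes that `train_test_split` rounds the test share up and gives the remainder to training; `split_sizes` is a hypothetical helper written for illustration, not part of the `datasets` library.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def split_sizes(n, test_size=0.2):\n",
    "    '''Return (n_train, n_test), assuming the test share is rounded up.'''\n",
    "    n_test = math.ceil(n * test_size)\n",
    "    return n - n_test, n_test\n",
    "\n",
    "n_recipes = 136048                      # number of recipes reported above\n",
    "n_train, n_test = split_sizes(n_recipes)\n",
    "\n",
    "# Two prompt types are built from the same recipes and concatenated,\n",
    "# so each split contains twice as many rows.\n",
    "total_train, total_test = 2 * n_train, 2 * n_test\n",
    "print(total_train, total_test)\n",
    "```\n",
    "\n",
    "Checking sizes this way is a cheap guard against accidentally dropping or duplicating rows when several prompt types are concatenated."
   ]
  },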
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "49f99c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 转为torch dataset\n",
    "ds_train = datasets.Dataset.from_pandas(train_data)\n",
    "ds_test = datasets.Dataset.from_pandas(test_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0a43f28",
   "metadata": {},
   "source": [
    "### 3. token编码"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4f37732",
   "metadata": {},
   "source": [
    "为了将文本数据喂入模型，需要将词转换为token。也就是把prompt转化成prompt_ids，把response转化成response_ids.\n",
    "\n",
    "同时，我们还需要将prompt_ids和response_ids拼接到一起作为模型的input_ids。\n",
    "\n",
    "这是为什么呢？\n",
    "\n",
    "因为ChatGLM2基座模型是一个Transformer结构，是一个被预选练过的纯粹的语言模型(LLM，Large Lauguage Model)。\n",
    "\n",
    "一个纯粹的语言模型，本质上只能做一件事情，那就是计算任意一段话像'人话'的概率。\n",
    "\n",
    "我们将prompt和response拼接到一起作为input_ids， ChatGLM2 就可以判断这段对话像'人类对话'的概率。\n",
    "\n",
    "在训练的时候我们使用梯度下降的方法来让ChatGLM2的判断更加准确。\n",
    "\n",
    "训练完成之后，在预测的时候，我们就可以利用贪心搜索或者束搜索的方法按照最像\"人类对话\"的方式进行更合理的文本生成。"
   ]
  },
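  {
   "cell_type": "markdown",
   "id": "sketch-masked-loss",
   "metadata": {},
   "source": [
    "The idea of concatenating prompt and response, while training only on the response, can be illustrated with toy numbers. This is a simplified sketch: the token ids and per-token probabilities are made up, `build_example` and `masked_nll` are hypothetical helpers, and the one-position shift that a causal language model applies between inputs and labels is ignored for clarity.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "IGNORE_INDEX = -100  # label value that the loss skips\n",
    "\n",
    "def build_example(context_ids, target_ids, eos_id):\n",
    "    '''Concatenate prompt and response; mask the prompt part of the labels.'''\n",
    "    input_ids = context_ids + target_ids + [eos_id]\n",
    "    labels = [IGNORE_INDEX] * len(context_ids) + target_ids + [eos_id]\n",
    "    return input_ids, labels\n",
    "\n",
    "def masked_nll(token_probs, labels):\n",
    "    '''Mean negative log-likelihood over the positions that are not masked.'''\n",
    "    kept = [p for p, y in zip(token_probs, labels) if y != IGNORE_INDEX]\n",
    "    return -sum(math.log(p) for p in kept) / len(kept)\n",
    "\n",
    "input_ids, labels = build_example([11, 12, 13], [21, 22], eos_id=2)\n",
    "\n",
    "# Made-up probabilities that a model assigns to the correct token at each position\n",
    "probs = [0.9, 0.1, 0.5, 0.25, 0.5, 0.5]\n",
    "loss = masked_nll(probs, labels)\n",
    "print(input_ids, labels, round(loss, 3))\n",
    "```\n",
    "\n",
    "Only the last three positions (the response and the end-of-sequence token) contribute to the loss, so a low probability on a prompt token, such as the 0.1 above, does not affect training."
   ]
  },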
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "d5a6b462",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "\n",
    "# 定义最大序列长度，用于对文本进行截断或填充\n",
    "max_seq_length = 1024\n",
    "# 是否过滤\n",
    "skip_over_length = True\n",
    "\n",
    "config = transformers.AutoConfig.from_pretrained(model_name, trust_remote_code=True, device_map='auto')\n",
    "\n",
    "# 样本预处理\n",
    "def preprocess(example):\n",
    "    \n",
    "    context = example[\"prompt\"] \n",
    "    target = example[\"response\"]\n",
    "    \n",
    "    # 使用tokenizer对context进行编码，将其转换为数字id序列。\n",
    "    context_ids = tokenizer.encode(\n",
    "            context, \n",
    "            max_length=max_seq_length,\n",
    "            truncation=True)\n",
    "\n",
    "    # 使用tokenizer对target进行编码，将其转换为数字id序列，同时移除特殊标记。\n",
    "    target_ids = tokenizer.encode(\n",
    "        target,\n",
    "        max_length=max_seq_length,\n",
    "        truncation=True,\n",
    "        add_special_tokens=False)\n",
    "    \n",
    "    # 将context_ids、target_ids和config.eos_token_id组合成一个输入序列。\n",
    "    input_ids = context_ids + target_ids + [config.eos_token_id] # End of Sentence\n",
    "    \n",
    "    # -100标志位后面会在计算loss时会被忽略不贡献损失，我们集中优化target部分生成的loss\n",
    "    labels = [-100]*len(context_ids)+ target_ids + [config.eos_token_id]\n",
    "    \n",
    "    return {\"input_ids\": input_ids,\n",
    "            \"labels\": labels,\n",
    "            \"context_len\": len(context_ids),\n",
    "            'target_len':len(target_ids)+1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "61a3daa0-7275-4205-9d3f-014766b2eb11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "90c87bd1",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['input_ids', 'labels', 'context_len', 'target_len'],\n",
      "    num_rows: 217512\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "# 用于训练的Token保存后的文件名\n",
    "train_pickle_file = \"ds_train_token.pickle\"\n",
    "\n",
    "# 处理过程比较耗时，所以判断文件是否存在，可以直接读取\n",
    "if os.path.exists(train_pickle_file):\n",
    "    # 读取\n",
    "    with open(train_pickle_file, 'rb') as f:\n",
    "        ds_train_token = pickle.load(f)\n",
    "        # 打印基本信息\n",
    "        print(ds_train_token)\n",
    "else:\n",
    "    # 用preprocess函数对数据集中的每个样本进行预处理操作\n",
    "    ds_train_token = ds_train.map(preprocess).select_columns(['input_ids','labels', 'context_len','target_len'])\n",
    "\n",
    "    # 用于过滤的条件是example中的\"context_len\"列和\"target_len\"列要小于max_seq_length。只有满足这个条件的样本才会被保留下来。\n",
    "    if skip_over_length:\n",
    "        ds_train_token = ds_train_token.filter(\n",
    "            lambda example: example[\"context_len\"]<max_seq_length and example[\"target_len\"]<max_seq_length)\n",
    "    # 保存到磁盘\n",
    "    with open(train_pickle_file, 'wb') as f:\n",
    "        pickle.dump(ds_train_token, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "87d52a30",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['input_ids', 'labels', 'context_len', 'target_len'],\n",
      "    num_rows: 54420\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "# 用于测试的Token保存后的文件名\n",
    "test_pickle_file = \"ds_test_token.pickle\"\n",
    "if os.path.exists(test_pickle_file):\n",
    "    with open(test_pickle_file, 'rb') as f:\n",
    "        ds_test_token = pickle.load(f)\n",
    "        print(ds_test_token)\n",
    "else:\n",
    "    ds_test_token = ds_test.map(preprocess).select_columns(['input_ids', 'labels','context_len','target_len'])\n",
    "    if skip_over_length:\n",
    "        ds_val_token = ds_test_token.filter(\n",
    "            lambda example: example[\"context_len\"]<max_seq_length and example[\"target_len\"]<max_seq_length)\n",
    "    \n",
    "    with open(test_pickle_file, 'wb') as f:\n",
    "        pickle.dump(ds_test_token, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "287532fb",
   "metadata": {},
   "source": [
    "### 4 构建训练数据管道"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "90ccb6a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_collator(examples: list):\n",
    "    len_ids = [len(example[\"input_ids\"]) for example in examples]\n",
    "    longest = max(len_ids) #之后按照batch中最长的input_ids进行padding\n",
    "    \n",
    "    input_ids = []\n",
    "    labels_list = []\n",
    "    \n",
    "    for length, example in sorted(zip(len_ids, examples), key=lambda x: -x[0]):\n",
    "        ids = example[\"input_ids\"]\n",
    "        labs = example[\"labels\"]\n",
    "        \n",
    "        ids = ids + [tokenizer.pad_token_id] * (longest - length)\n",
    "        labs = labs + [-100] * (longest - length)\n",
    "        \n",
    "        input_ids.append(torch.LongTensor(ids))\n",
    "        labels_list.append(torch.LongTensor(labs))\n",
    "          \n",
    "    input_ids = torch.stack(input_ids)\n",
    "    labels = torch.stack(labels_list)\n",
    "    return {\n",
    "        \"input_ids\": input_ids,\n",
    "        \"labels\": labels,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "2cd45ee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据加载器\n",
    "# 在训练过程中，GPU现存消耗大概在15G左右，如果不够，可以适当调小batch_size, num_workers 的大小\n",
    "dl_train = torch.utils.data.DataLoader(ds_train_token,num_workers=2,batch_size=8,\n",
    "                                       pin_memory=True,shuffle=True, collate_fn = data_collator)\n",
    "dl_test = torch.utils.data.DataLoader(ds_test_token,num_workers=2,batch_size=8,\n",
    "                                    pin_memory=True,shuffle=True, collate_fn = data_collator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "b7f32ebe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[64790, 64792, 30910,  ..., 49837,    13,     2],\n",
       "         [64790, 64792, 30910,  ...,     0,     0,     0],\n",
       "         [64790, 64792, 30910,  ...,     0,     0,     0],\n",
       "         ...,\n",
       "         [64790, 64792, 30910,  ...,     0,     0,     0],\n",
       "         [64790, 64792, 30910,  ...,     0,     0,     0],\n",
       "         [64790, 64792, 44497,  ...,     0,     0,     0]]),\n",
       " 'labels': tensor([[ -100,  -100,  -100,  ..., 49837,    13,     2],\n",
       "         [ -100,  -100,  -100,  ...,  -100,  -100,  -100],\n",
       "         [ -100,  -100,  -100,  ...,  -100,  -100,  -100],\n",
       "         ...,\n",
       "         [ -100,  -100,  -100,  ...,  -100,  -100,  -100],\n",
       "         [ -100,  -100,  -100,  ...,  -100,  -100,  -100],\n",
       "         [ -100,  -100,  -100,  ...,  -100,  -100,  -100]])}"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看一个Batch的数据\n",
    "for batch in dl_train:\n",
    "    break \n",
    "batch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5dc05df",
   "metadata": {},
   "source": [
    "### 5. 构建模型\n",
    "\n",
    "本示例使用QLoRA来调优ChatGLM2-6B模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "6dad32e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "3ee071aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel, TrainingArguments, AutoConfig\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "\n",
    "model.supports_gradient_checkpointing = True  #节约cuda\n",
    "model.gradient_checkpointing_enable()\n",
    "model.enable_input_require_grads()\n",
    "\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "6b5d5c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预处理量化模型，以适配LoRA调优\n",
    "from peft import prepare_model_for_kbit_training\n",
    "model = prepare_model_for_kbit_training(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "61a0ba8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['dense', 'dense_h_to_4h', 'query_key_value', 'dense_4h_to_h']\n"
     ]
    }
   ],
   "source": [
    "# 找出所有的全连接，为全连接层添加LoRA适配器\n",
    "import bitsandbytes as bnb\n",
    "\n",
    "def find_all_linear_modules(model):\n",
    "    cls = bnb.nn.Linear4bit\n",
    "    lora_module_names = set()\n",
    "    \n",
    "    for name, module in model.named_modules():\n",
    "        if isinstance(module, cls):\n",
    "            names = name.split('.')\n",
    "            lora_module_names.add(name[0] if len(names) == 1 else names[-1])\n",
    "            \n",
    "    if \"lm_head\" in lora_module_names:\n",
    "        lora_module_names.remove(\"lm_head\")\n",
    "    return list(lora_module_names)\n",
    "\n",
    "lora_modules = find_all_linear_modules(model)\n",
    "print(lora_modules)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "e7ad5a78",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 1,949,696 || all params: 6,245,533,696 || trainable%: 0.031217444255383614\n"
     ]
    }
   ],
   "source": [
    "# 定义LoraConfig对象，配置LoRA适配器的参数\n",
    "peft_config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,  # 任务类型，这里设置为TaskType.CAUSAL_LM，表示是一个因果语言建模任务。\n",
    "    inference_mode=False,          # 推断模式，这里设置为False，表示训练模式。\n",
    "    r=8,\n",
    "    lora_alpha=32, \n",
    "    lora_dropout=0.1,\n",
    "    # target_modules=lora_modules\n",
    ")\n",
    "\n",
    "# 根据配置获取PEFT模型\n",
    "model = get_peft_model(model, peft_config)\n",
    "\n",
    "# 设置模型为可并行的\n",
    "model.is_parallelizable = True\n",
    "\n",
    "# 设置模型为模型并行\n",
    "model.model_parallel = True\n",
    "\n",
    "# 打印可训练的参数\n",
    "model.print_trainable_parameters()"
   ]
  },
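  {
   "cell_type": "markdown",
   "id": "sketch-lora-param-count",
   "metadata": {},
   "source": [
    "The trainable parameter count printed above can be sanity-checked by hand. This sketch rests on several assumptions that are not stated in this notebook: that, with `target_modules` left commented out, PEFT falls back to adapting only the `query_key_value` projection, and that ChatGLM2-6B uses a hidden size of 4096, 28 layers, and multi-query attention with 2 key/value groups of 128 channels each. `lora_params` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def lora_params(in_features, out_features, r):\n",
    "    '''LoRA adds two matrices to a linear layer: A (r x in) and B (out x r).'''\n",
    "    return r * (in_features + out_features)\n",
    "\n",
    "HIDDEN = 4096       # assumed hidden size\n",
    "KV_CHANNELS = 128   # assumed per-head dimension\n",
    "KV_GROUPS = 2       # assumed number of key/value groups\n",
    "NUM_LAYERS = 28     # assumed number of transformer layers\n",
    "R = 8               # LoRA rank, as set in the LoraConfig above\n",
    "\n",
    "# Output width of the fused projection: queries plus grouped keys and values\n",
    "qkv_out = HIDDEN + 2 * KV_GROUPS * KV_CHANNELS\n",
    "\n",
    "total = NUM_LAYERS * lora_params(HIDDEN, qkv_out, R)\n",
    "print(total)\n",
    "```\n",
    "\n",
    "If the assumptions hold, the result equals the `trainable params` value reported by `print_trainable_parameters()`. Uncommenting `target_modules=lora_modules` would attach adapters to the other linear layers as well and increase this number."
   ]
  },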
  {
   "cell_type": "markdown",
   "id": "3e4a88e0",
   "metadata": {},
   "source": [
    "### 6. 训练模型\n",
    "\n",
    "此处我们使用torchkeras工具包来构建训练流程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "44958019",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchkeras import KerasModel \n",
    "from accelerate import Accelerator "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "569091f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class StepRunner:\n",
    "    def __init__(self, net, loss_fn, accelerator=None, stage = \"train\", metrics_dict = None, \n",
    "                 optimizer = None, lr_scheduler = None\n",
    "                 ):\n",
    "        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage\n",
    "        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler\n",
    "        self.accelerator = accelerator if accelerator is not None else Accelerator() \n",
    "        if self.stage=='train':\n",
    "            self.net.train() \n",
    "        else:\n",
    "            self.net.eval()\n",
    "    \n",
    "    def __call__(self, batch):\n",
    "        \n",
    "        #loss\n",
    "        with self.accelerator.autocast():\n",
    "            loss = self.net(input_ids=batch[\"input_ids\"],labels=batch[\"labels\"]).loss\n",
    "\n",
    "        #backward()\n",
    "        if self.optimizer is not None and self.stage==\"train\":\n",
    "            self.accelerator.backward(loss)\n",
    "            if self.accelerator.sync_gradients:\n",
    "                self.accelerator.clip_grad_norm_(self.net.parameters(), 1.0)\n",
    "            self.optimizer.step()\n",
    "            if self.lr_scheduler is not None:\n",
    "                self.lr_scheduler.step()\n",
    "            self.optimizer.zero_grad()\n",
    "            \n",
    "        all_loss = self.accelerator.gather(loss).sum()\n",
    "        \n",
    "        #losses (or plain metrics that can be averaged)\n",
    "        step_losses = {self.stage+\"_loss\":all_loss.item()}\n",
    "        \n",
    "        #metrics (stateful metrics)\n",
    "        step_metrics = {}\n",
    "        \n",
    "        if self.stage==\"train\":\n",
    "            if self.optimizer is not None:\n",
    "                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']\n",
    "            else:\n",
    "                step_metrics['lr'] = 0.0\n",
    "        return step_losses,step_metrics\n",
    "    \n",
    "KerasModel.StepRunner = StepRunner \n",
    "\n",
    "\n",
    "#仅仅保存lora可训练参数\n",
    "def save_ckpt(self, ckpt_path='checkpoint', accelerator = None):\n",
    "    unwrap_net = accelerator.unwrap_model(self.net)\n",
    "    unwrap_net.save_pretrained(ckpt_path)\n",
    "    \n",
    "def load_ckpt(self, ckpt_path='checkpoint'):\n",
    "    self.net = self.net.from_pretrained(self.net,ckpt_path)\n",
    "    self.from_scratch = False\n",
    "    \n",
    "KerasModel.save_ckpt = save_ckpt \n",
    "KerasModel.load_ckpt = load_ckpt "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "15d5a9d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "keras_model = KerasModel(model,loss_fn = None,optimizer=torch.optim.AdamW(model.parameters(),lr=2e-6))\n",
    "ckpt_path = 'meishi_chatglm2_qlora'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "7a7fdee4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31m<<<<<< ⚡️ cuda is used >>>>>>\u001b[0m\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGJCAYAAABYRTOkAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArKElEQVR4nO3de1TVVf7/8dcB5IAXQCUBlcRLmpqCYTJIZU4YZem4uoh2wbEcx7Iy6SZWolnSVDpOillNd+srpqPz/aZpSuqsEscJpNHyknmdfoGXFBINGs7+/eHqzJwABTywAZ+PtT5rcfbZ+3zeZ/fx8Gp/PueDwxhjBAAAYImP7QIAAMCFjTACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAlg2ffp0ORwOHT161HYp9Wb//v1yOBx666236nQMgMaBMAJcoGbNmqUVK1bYLgP/ZdOmTbrtttt08cUXq2XLlho4cKA2btxouyygzhFGgAsUYaThueOOO3Ts2DE98sgjevbZZ3X06FFdf/312rlzp+3SgDrlZ7sAAMAZixcvVlxcnPvxDTfcoB49emjZsmV64oknLFYG1C1WRoAG4ujRoxo5cqSCgoLUtm1bTZo0ST/++GOFfosWLVJsbKwCAwPVpk0bjRo1SocOHfLo8/XXX+uWW25ReHi4AgIC1LFjR40aNUpFRUWSJIfDoZKSEr399ttyOBxyOBz67W9/W2ldhYWF8vPz04wZMyo8t2vXLjkcDs2fP1+S9P333+uRRx5Rnz591LJlSwUFBemGG27QF198cZ6zU7VPPvlEV111lVq0aKGQkBD95je/0Y4dOzz6/PDDD3rooYcUFRUlp9Opdu3aaciQIcrLy3P3Odec1Yf/DiKSFBAQIEkqKyurtxoAG1gZARqIkSNHKioqShkZGdq8ebNeeuklHT9+XO+88467z7PPPqunnnpKI0eO1Lhx43TkyBHNmzdPV199tbZu3aqQkBCVlZUpKSlJpaWleuCBBxQeHq5vv/1WH374oU6cOKHg4GC9++67GjdunAYMGKDx48dLkrp27VppXWFhYRo0aJCWLFmi9PR0j+eysrLk6+ur2267TZK0d+9erVixQrfddps6d+6swsJCvfLKKxo0aJC++uortW/f3qtztm7dOt1www3q0qWLpk+frtOnT2vevHlKSEhQXl6eoqKiJEkTJkzQ0qVLdf/996tXr146duyYPv30U+3YsUOXX355teasKqdOndKpU6fOWauvr69at25d7ffmcrn08MMPy+l06o477qj2OKBRMgCsSk9PN5LM8OHDPdrvu+8+I8l88cUXxhhj9u/fb3x9fc2zzz7r0W/btm3Gz8/P3b5161YjyXzwwQdn3W+LFi3MmDFjqlXjK6+8YiSZbdu2ebT36tXL/PrXv3Y//vHHH015eblHn3379hmn02mefvppjzZJ5s0336zW/qsaExMTY9q1a2eOHTvmbvviiy+Mj4+PSUlJcbcFBwebiRMnVvna1Z2zyvz83+9cW6dOnWr0uuPHjzcOh8O8//77Na4JaGxYGQEaiIkTJ3o8fuCBB7RgwQKtWrVKffv21V/+8he5XC6NHDnS42vA4eHhuuSSS7R+/XpNnTrV/X/xa9as0dChQ9W8efPzru3mm2/WxIkTlZWVpcsuu0yStH37dn311VeaNGmSu5/T6XT/XF5erhMnTqhly5bq0aOHxykRb/juu++Un5+vxx57TG3atHG39+3bV0OGDNGqVavcbSEhIfr73/+u//f//l+lqzPnM2cpKSm68sorz9kvMDCw2q/5+uuv69VXX9WcOXM0evToao8DGivCCNBAXHLJJR6Pu3btKh8fH+3fv1/SmWsajDEV+v2sWbNmkqTOnTsrNTVVc+bM0XvvvaerrrpKw4cP15133nnW0w1nExoaqmuvvVZLlizRzJkzJZ05RePn56eb
b77Z3c/lculPf/qTFixYoH379qm8vNz9XNu2bWu176ocOHBAktSjR48Kz/Xs2VNr1qxRSUmJWrRooeeff15jxoxRZGSkYmNjNXToUKWkpKhLly6Szm/OunTp4n4db3n33XfVvXt3TZ482auvCzRUXMAKNFAOh8PjscvlksPh0OrVq7V27doK2yuvvOLuO3v2bP3zn//U1KlTdfr0aT344IPq3bu3/vWvf9W6nlGjRmn37t3Kz8+XJC1ZskTXXnutQkND3X1mzZql1NRUXX311Vq0aJHWrFmjtWvXqnfv3nK5XLXe9/kaOXKk9u7dq3nz5ql9+/Z64YUX1Lt3b3300UfuPrWds5MnT6qgoOCc25EjR6pd77FjxxQREVHr9ws0NqyMAA3E119/rc6dO7sf79mzRy6Xy30RZteuXWWMUefOndW9e/dzvl6fPn3Up08fPfnkk9q0aZMSEhK0cOFCPfPMM5Iqhp1zGTFihH7/+98rKytLkrR7926lpaV59Fm6dKkGDx6s119/3aP9xIkTHqHFGzp16iTpzDd6fmnnzp0KDQ1VixYt3G0RERG67777dN999+nw4cO6/PLL9eyzz+qGG25w9znXnFXmxRdfrPSbRpXV+/Mq17mMHj3ao3agqSOMAA1EZmamrrvuOvfjefPmSZL7l+XNN9+stLQ0zZgxQ4sWLfIIE8YYff/992rbtq2Ki4vVvHlz+fn95593nz595OPjo9LSUndbixYtdOLEiWrXFxISoqSkJC1ZskTGGPn7+2vEiBEefXx9fWWM8Wj74IMP9O2336pbt27V3ld1REREKCYmRm+//bbS0tIUEhIi6cy1LB9//LHuvPNOSWeuXTl58qTH6ZZ27dqpffv27vmo7pxVpi6uGUlOTnafdgMuBIQRoIHYt2+fhg8fruuvv145OTlatGiRbr/9dkVHR0s6szLyzDPPKC0tTfv379eIESPUqlUr7du3T8uXL9f48eP1yCOP6JNPPtH999+v2267Td27d9e///1vvfvuu/L19dUtt9zi3l9sbKzWrVunOXPmqH379urcuXOF+1z8UnJysu68804tWLBASUlJ7gDws5tuuklPP/20xo4dq4EDB2rbtm167733vH5Nxc9eeOEF3XDDDYqPj9c999zj/mpvcHCwpk+fLunMPUY6duyoW2+9VdHR0WrZsqXWrVunf/zjH5o9e7YkVXvOKlMX14xce+21ioqK0oYNG7z6ukCDZfW7PADcXw396quvzK233mpatWplWrdube6//35z+vTpCv2XLVtmrrzyStOiRQvTokULc+mll5qJEyeaXbt2GWOM2bt3r7n77rtN165dTUBAgGnTpo0ZPHiwWbduncfr7Ny501x99dUmMDDQSKrW13yLi4vd/RctWlTh+R9//NE8/PDDJiIiwgQGBpqEhASTk5NjBg0aZAYNGuTu562v9hpjzLp160xCQoIJDAw0QUFBZtiwYearr75yP19aWmoeffRREx0dbVq1amVatGhhoqOjzYIFC9x9qjtn9aVTp04e8wU0dQ5jfrGmCgAAUI/4Ng0AALCKa0YAWFVWVqbvv//+rH2Cg4NrdAEogMaFMALAqk2bNmnw4MFn7fPmm29W+Yf8ADR+Vq8Z+dvf/qYXXnhBubm5+u6777R8+fIKXxX8pQ0bNig1NVVffvmlIiMj9eSTT/IhBTRix48fV25u7ln79O7dm5uAAU2Y1ZWRkpISRUdH6+677/a4pXRV9u3bpxtvvFETJkzQe++9p+zsbI0bN04RERFKSkqqh4oBeFvr1q2VmJhouwwAFjWYb9M4HI5zrow8/vjjWrlypbZv3+5uGzVqlE6cOKHVq1fXQ5UAAMDbGtU1Izk5ORX+DyopKUkPPfRQlWNKS0s97qDocrncd6qs6e2wAQC4kBlj9MMPP6h9+/by8fHeF3IbVRgpKChQWFiYR1tYWJiKi4t1+vTpSq+2z8jIqNbfjQAAANVz6NAhdezY0Wuv
16jCSG2kpaUpNTXV/bioqEgXX3yxDh06pKCgIIuVAQDQuBQXFysyMlKtWrXy6us2qjASHh6uwsJCj7bCwkIFBQVVeQ8Cp9Mpp9NZoT0oKIgwAgBALXj7ModGdQfW+Ph4ZWdne7StXbtW8fHxlioCAADny2oYOXnypPLz85Wfny/pzFd38/PzdfDgQUlnTrGkpKS4+0+YMEF79+7VY489pp07d2rBggVasmSJJk+ebKN8AADgBVbDyOeff65+/fqpX79+kqTU1FT169dP06ZNkyR999137mAiSZ07d9bKlSu1du1aRUdHa/bs2frzn//MPUYAAGjEGsx9RupLcXGxgoODVVRUxDUjAADUQF39Dm1U14wAAICmhzACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACssh5GMjMzFRUVpYCAAMXFxWnLli1n7T937lz16NFDgYGBioyM1OTJk/Xjjz/WU7UAAMDbrIaRrKwspaamKj09XXl5eYqOjlZSUpIOHz5caf/3339fU6ZMUXp6unbs2KHXX39dWVlZmjp1aj1XDgAAvMVqGJkzZ45+97vfaezYserVq5cWLlyo5s2b64033qi0/6ZNm5SQkKDbb79dUVFRuu666zR69OhzrqYAAICGy1oYKSsrU25urhITE/9TjI+PEhMTlZOTU+mYgQMHKjc31x0+9u7dq1WrVmno0KFV7qe0tFTFxcUeGwAAaDj8bO346NGjKi8vV1hYmEd7WFiYdu7cWemY22+/XUePHtWVV14pY4z+/e9/a8KECWc9TZORkaEZM2Z4tXYAAOA91i9grYkNGzZo1qxZWrBggfLy8vSXv/xFK1eu1MyZM6sck5aWpqKiIvd26NCheqwYAACci7WVkdDQUPn6+qqwsNCjvbCwUOHh4ZWOeeqpp3TXXXdp3LhxkqQ+ffqopKRE48eP1xNPPCEfn4rZyul0yul0ev8NAAAAr7C2MuLv76/Y2FhlZ2e721wul7KzsxUfH1/pmFOnTlUIHL6+vpIkY0zdFQsAAOqMtZURSUpNTdWYMWPUv39/DRgwQHPnzlVJSYnGjh0rSUpJSVGHDh2UkZEhSRo2bJjmzJmjfv36KS4uTnv27NFTTz2lYcOGuUMJAABoXKyGkeTkZB05ckTTpk1TQUGBYmJitHr1avdFrQcPHvRYCXnyySflcDj05JNP6ttvv9VFF12kYcOG6dlnn7X1FgAAwHlymAvs/EZxcbGCg4NVVFSkoKAg2+UAANBo1NXv0Eb1bRoAAND0EEYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGCV9TCS
mZmpqKgoBQQEKC4uTlu2bDlr/xMnTmjixImKiIiQ0+lU9+7dtWrVqnqqFgAAeJufzZ1nZWUpNTVVCxcuVFxcnObOnaukpCTt2rVL7dq1q9C/rKxMQ4YMUbt27bR06VJ16NBBBw4cUEhISP0XDwAAvMJhjDG2dh4XF6crrrhC8+fPlyS5XC5FRkbqgQce0JQpUyr0X7hwoV544QXt3LlTzZo1q9U+i4uLFRwcrKKiIgUFBZ1X/QAAXEjq6neotdM0ZWVlys3NVWJi4n+K8fFRYmKicnJyKh3zv//7v4qPj9fEiRMVFhamyy67TLNmzVJ5eXmV+yktLVVxcbHHBgAAGg5rYeTo0aMqLy9XWFiYR3tYWJgKCgoqHbN3714tXbpU5eXlWrVqlZ566inNnj1bzzzzTJX7ycjIUHBwsHuLjIz06vsAAADnx/oFrDXhcrnUrl07vfrqq4qNjVVycrKeeOIJLVy4sMoxaWlpKioqcm+HDh2qx4oBAMC5WLuANTQ0VL6+viosLPRoLywsVHh4eKVjIiIi1KxZM/n6+rrbevbsqYKCApWVlcnf37/CGKfTKafT6d3iAQCA11hbGfH391dsbKyys7PdbS6XS9nZ2YqPj690TEJCgvbs2SOXy+Vu2717tyIiIioNIgAAoOGzepomNTVVr732mt5++23t2LFD9957r0pKSjR27FhJUkpKitLS0tz97733Xn3//feaNGmSdu/erZUrV2rWrFmaOHGirbcAAADOk9X7jCQnJ+vIkSOaNm2aCgoKFBMTo9WrV7svaj148KB8fP6TlyIjI7VmzRpNnjxZffv2VYcOHTRp0iQ9/vjjtt4CAAA4T1bvM2ID9xkBAKB2mtx9RgAAACTCCAAAsIwwAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArKpVGHn77be1cuVK9+PHHntMISEhGjhwoA4cOOC14gAAQNNXqzAya9YsBQYGSpJycnKUmZmp559/XqGhoZo8ebJXCwQAAE2bX20GHTp0SN26dZMkrVixQrfccovGjx+vhIQEXXPNNd6sDwAANHG1Whlp2bKljh07Jkn6+OOPNWTIEElSQECATp8+7b3qAABAk1erlZEhQ4Zo3Lhx6tevn3bv3q2hQ4dKkr788ktFRUV5sz4AANDE1WplJDMzU/Hx8Tpy5IiWLVumtm3bSpJyc3M1evRorxYIAACaNocxxtguoj4VFxcrODhYRUVFCgoKsl0OAACNRl39Dq3Vysjq1av16aefuh9nZmYqJiZGt99+u44fP+614gAAQNNXqzDy6KOPqri4WJK0bds2Pfzwwxo6dKj27dun1NRUrxYIAACatlpdwLpv3z716tVLkrRs2TLddNNNmjVrlvLy8twXswIAAFRHrVZG/P39derUKUnSunXrdN1110mS2rRp414xAQAAqI5arYxceeWVSk1NVUJCgrZs2aKsrCxJ0u7du9WxY0evFggAAJq2Wq2MzJ8/X35+flq6dKlefvlldejQQZL00Ucf6frrr/dqgQAAoGnjq70AAKBa6up3aK1O00hSeXm5VqxYoR07dkiSevfureHDh8vX19drxQEAgKavVmFkz549Gjp0qL799lv16NFDkpSRkaHIyEitXLlSXbt29WqRAACg6arVNSMPPvigunbtqkOHDikvL095eXk6ePCgOnfurAcffNDbNQIAgCasVisjGzdu1ObNm9WmTRt3W9u2bfXcc88pISHBa8UBAICmr1YrI06nUz/88EOF9pMnT8rf3/+8iwIAABeOWoWRm266SePHj9ff//53GWNkjNHmzZs1YcIEDR8+3Ns1AgCAJqxWYeSll15S165dFR8fr4CAAAUEBGjgwIHq1q2b5s6d6+USAQBAU1ara0ZCQkL017/+VXv27HF/tbdnz57q1q2bV4sDAABNX7XDyLn+
Gu/69evdP8+ZM6f2FQEAgAtKtcPI1q1bq9XP4XDUuhgAAHDhqXYY+e+VDwAAAG+p1QWsAAAA3kIYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFUNIoxkZmYqKipKAQEBiouL05YtW6o1bvHixXI4HBoxYkTdFggAAOqM9TCSlZWl1NRUpaenKy8vT9HR0UpKStLhw4fPOm7//v165JFHdNVVV9VTpQAAoC5YDyNz5szR7373O40dO1a9evXSwoUL1bx5c73xxhtVjikvL9cdd9yhGTNmqEuXLvVYLQAA8DarYaSsrEy5ublKTEx0t/n4+CgxMVE5OTlVjnv66afVrl073XPPPefcR2lpqYqLiz02AADQcFgNI0ePHlV5ebnCwsI82sPCwlRQUFDpmE8//VSvv/66XnvttWrtIyMjQ8HBwe4tMjLyvOsGAADeY/00TU388MMPuuuuu/Taa68pNDS0WmPS0tJUVFTk3g4dOlTHVQIAgJrws7nz0NBQ+fr6qrCw0KO9sLBQ4eHhFfp/88032r9/v4YNG+Zuc7lckiQ/Pz/t2rVLXbt29RjjdDrldDrroHoAAOANVldG/P39FRsbq+zsbHeby+VSdna24uPjK/S/9NJLtW3bNuXn57u34cOHa/DgwcrPz+cUDAAAjZDVlRFJSk1N1ZgxY9S/f38NGDBAc+fOVUlJicaOHStJSklJUYcOHZSRkaGAgABddtllHuNDQkIkqUI7AABoHKyHkeTkZB05ckTTpk1TQUGBYmJitHr1avdFrQcPHpSPT6O6tAUAANSAwxhjbBdRn4qLixUcHKyioiIFBQXZLgcAgEajrn6HsuQAAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrGkQYyczMVFRUlAICAhQXF6ctW7ZU2fe1117TVVddpdatW6t169ZKTEw8a38AANCwWQ8jWVlZSk1NVXp6uvLy8hQdHa2kpCQdPny40v4bNmzQ6NGjtX79euXk5CgyMlLXXXedvv3223quHAAAeIPDGGNsFhAXF6crrrhC8+fPlyS5XC5FRkbqgQce0JQpU845vry8XK1bt9b8+fOVkpJyzv7FxcUKDg5WUVGRgoKCzrt+AAAuFHX1O9TqykhZWZlyc3OVmJjobvPx8VFiYqJycnKq9RqnTp3STz/9pDZt2lT6fGlpqYqLiz02AADQcFgNI0ePHlV5ebnCwsI82sPCwlRQUFCt13j88cfVvn17j0Dz3zIyMhQcHOzeIiMjz7tuAADgPdavGTkfzz33nBYvXqzly5crICCg0j5paWkqKipyb4cOHarnKgEAwNn42dx5aGiofH19VVhY6NFeWFio8PDws4598cUX9dxzz2ndunXq27dvlf2cTqecTqdX6gUAAN5ndWXE399fsbGxys7Odre5XC5lZ2crPj6+ynHPP/+8Zs6cqdWrV6t///71USoAAKgjVldGJCk1NVVjxoxR//79NWDAAM2dO1clJSUaO3asJCklJUUdOnRQRkaGJOkPf/iDpk2bpvfff19RUVHua0tatmypli1bWnsfAACgdqyHkeTkZB05ckTTpk1TQUGBYmJitHr1avdFrQcPHpSPz38WcF5++WWV
lZXp1ltv9Xid9PR0TZ8+vT5LBwAAXmD9PiP1jfuMAABQO03yPiMAAACEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWNUgwkhmZqaioqIUEBCguLg4bdmy5az9P/jgA1166aUKCAhQnz59tGrVqnqqFAAAeJv1MJKVlaXU1FSlp6crLy9P0dHRSkpK0uHDhyvtv2nTJo0ePVr33HOPtm7dqhEjRmjEiBHavn17PVcOAAC8wWGMMTYLiIuL0xVXXKH58+dLklwulyIjI/XAAw9oypQpFfonJyerpKREH374obvtV7/6lWJiYrRw4cJz7q+4uFjBwcEqKipSUFCQ994IAABNXF39DvXz2ivVQllZmXJzc5WWluZu8/HxUWJionJyciodk5OTo9TUVI+2pKQkrVixotL+paWlKi0tdT8uKiqSdGZCAQBA9f38u9Pb6xhWw8jRo0dVXl6usLAwj/awsDDt3Lmz0jEFBQWV9i8oKKi0f0ZGhmbMmFGhPTIyspZVAwBwYTt27JiCg4O99npWw0h9SEtL81hJOXHihDp16qSDBw96dSKbuuLiYkVGRurQoUOc3qom5qx2mLeaY85qh3mruaKiIl188cVq06aNV1/XahgJDQ2Vr6+vCgsLPdoLCwsVHh5e6Zjw8PAa9Xc6nXI6nRXag4ODOfhqISgoiHmrIeasdpi3mmPOaod5qzkfH+9+/8Xqt2n8/f0VGxur7Oxsd5vL5VJ2drbi4+MrHRMfH+/RX5LWrl1bZX8AANCwWT9Nk5qaqjFjxqh///4aMGCA5s6dq5KSEo0dO1aSlJKSog4dOigjI0OSNGnSJA0aNEizZ8/WjTfeqMWLF+vzzz/Xq6++avNtAACAWrIeRpKTk3XkyBFNmzZNBQUFiomJ0erVq90XqR48eNBjOWjgwIF6//339eSTT2rq1Km65JJLtGLFCl122WXV2p/T6VR6enqlp25QNeat5piz2mHeao45qx3mrebqas6s32cEAABc2KzfgRUAAFzYCCMAAMAqwggAALCKMAIAAKxqkmEkMzNTUVFRCggIUFxcnLZs2XLW/h988IEuvfRSBQQEqE+fPlq1alU9Vdqw1GTe3nrrLTkcDo8tICCgHqu1729/+5uGDRum9u3by+FwVPn3kf7bhg0bdPnll8vpdKpbt25666236rzOhqSmc7Zhw4YKx5nD4ajyzz80RRkZGbriiivUqlUrtWvXTiNGjNCuXbvOOe5C/1yrzbxd6J9rL7/8svr27eu+CVx8fLw++uijs47x1nHW5MJIVlaWUlNTlZ6erry8PEVHRyspKUmHDx+utP+mTZs0evRo3XPPPdq6datGjBihESNGaPv27fVcuV01nTfpzF0Lv/vuO/d24MCBeqzYvpKSEkVHRyszM7Na/fft26cbb7xRgwcPVn5+vh566CGNGzdOa9asqeNKG46aztnPdu3a5XGstWvXro4qbHg2btyoiRMnavPmzVq7dq1++uknXXfddSopKalyDJ9rtZs36cL+XOvYsaOee+455ebm6vPPP9evf/1r/eY3v9GXX35ZaX+vHmemiRkwYICZOHGi+3F5eblp3769ycjIqLT/yJEjzY033ujRFhcXZ37/+9/XaZ0NTU3n7c033zTBwcH1VF3DJ8ks
X778rH0ee+wx07t3b4+25ORkk5SUVIeVNVzVmbP169cbSeb48eP1UlNjcPjwYSPJbNy4sco+fK5VVJ1543OtotatW5s///nPlT7nzeOsSa2MlJWVKTc3V4mJie42Hx8fJSYmKicnp9IxOTk5Hv0lKSkpqcr+TVFt5k2STp48qU6dOikyMvKs6RlncKzVXkxMjCIiIjRkyBB99tlntsuxqqioSJLO+ofKONYqqs68SXyu/ay8vFyLFy9WSUlJlX9uxZvHWZMKI0ePHlV5ebn77q0/CwsLq/Icc0FBQY36N0W1mbcePXrojTfe0F//+lctWrRILpdLAwcO1L/+9a/6KLlRqupYKy4u1unTpy1V1bBFRERo4cKFWrZsmZYtW6bIyEhdc801ysvLs12aFS6XSw899JASEhLOetdpPtc8VXfe+FyTtm3bppYtW8rpdGrChAlavny5evXqVWlfbx5n1m8Hj8YpPj7eIy0PHDhQPXv21CuvvKKZM2darAxNSY8ePdSjRw/344EDB+qbb77RH//4R7377rsWK7Nj4sSJ2r59uz799FPbpTQq1Z03PtfO/JvLz89XUVGRli5dqjFjxmjjxo1VBhJvaVIrI6GhofL19VVhYaFHe2FhocLDwysdEx4eXqP+TVFt5u2XmjVrpn79+mnPnj11UWKTUNWxFhQUpMDAQEtVNT4DBgy4II+z+++/Xx9++KHWr1+vjh07nrUvn2v/UZN5+6UL8XPN399f3bp1U2xsrDIyMhQdHa0//elPlfb15nHWpMKIv7+/YmNjlZ2d7W5zuVzKzs6u8pxXfHy8R39JWrt2bZX9m6LazNsvlZeXa9u2bYqIiKirMhs9jjXvyM/Pv6COM2OM7r//fi1fvlyffPKJOnfufM4xHGu1m7df4nPtzO+C0tLSSp/z6nFWi4trG7TFixcbp9Np3nrrLfPVV1+Z8ePHm5CQEFNQUGCMMeauu+4yU6ZMcff/7LPPjJ+fn3nxxRfNjh07THp6umnWrJnZtm2brbdgRU3nbcaMGWbNmjXmm2++Mbm5uWbUqFEmICDAfPnll7beQr374YcfzNatW83WrVuNJDNnzhyzdetWc+DAAWOMMVOmTDF33XWXu//evXtN8+bNzaOPPmp27NhhMjMzja+vr1m9erWtt1Dvajpnf/zjH82KFSvM119/bbZt22YmTZpkfHx8zLp162y9hXp37733muDgYLNhwwbz3XffubdTp065+/C5VlFt5u1C/1ybMmWK2bhxo9m3b5/55z//aaZMmWIcDof5+OOPjTF1e5w1uTBijDHz5s0zF198sfH39zcDBgwwmzdvdj83aNAgM2bMGI/+S5YsMd27dzf+/v6md+/eZuXKlfVcccNQk3l76KGH3H3DwsLM0KFDTV5enoWq7fn5a6e/3H6epzFjxphBgwZVGBMTE2P8/f1Nly5dzJtvvlnvddtU0zn7wx/+YLp27WoCAgJMmzZtzDXXXGM++eQTO8VbUtl8SfI4dvhcq6g283ahf67dfffdplOnTsbf399cdNFF5tprr3UHEWPq9jhzGGNMzddTAAAAvKNJXTMCAAAaH8IIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAavQ0bNsjhcOjEiRO2SwFQC4QRAABgFWEEAABYRRgBcN5cLpcyMjLUuXNnBQYGKjo6WkuXLpX0n1MoK1euVN++fRUQEKBf/epX2r59u8drLFu2TL1795bT6VRUVJRmz57t8Xxpaakef/xxRUZGyul0qlu3bnr99dc9+uTm5qp///5q3ry5Bg4cqF27dtXtGwfgFYQRAOctIyND77zzjhYuXKgvv/xSkydP1p133qmNGze6+zz66KOaPXu2/vGPf+iiiy7SsGHD9NNPP0k6EyJGjhypUaNGadu2bZo+fbqeeuopvfXWW+7xKSkp+p//+R+99NJL2rFjh1555RW1bNnSo44nnnhCs2fP1uef
fy4/Pz/dfffd9fL+AZyn2v+xYQAw5scffzTNmzc3mzZt8mi/5557zOjRo8369euNJLN48WL3c8eOHTOBgYEmKyvLGGPM7bffboYMGeIx/tFHHzW9evUyxhiza9cuI8msXbu20hp+3se6devcbStXrjSSzOnTp73yPgHUHVZGAJyXPXv26NSpUxoyZIhatmzp3t555x1988037n7x8fHun9u0aaMePXpox44dkqQdO3YoISHB43UTEhL09ddfq7y8XPn5+fL19dWgQYPOWkvfvn3dP0dEREiSDh8+fN7vEUDd8rNdAIDG7eTJk5KklStXqkOHDh7POZ1Oj0BSW4GBgdXq16xZM/fPDodD0pnrWQA0bKyMADgvvXr1ktPp1MGDB9WtWzePLTIy0t1v8+bN7p+PHz+u3bt3q2fPnpKknj176rPPPvN43c8++0zdu3eXr6+v+vTpI5fL5XENCoCmg5URAOelVatWeuSRRzR58mS5XC5deeWVKioq0meffaagoCB16tRJkvT000+rbdu2CgsL0xNPPKHQ0FCNGDFCkvTwww/riiuu0MyZM5WcnKycnBzNnz9fCxYskCRFRUVpzJgxuvvuu/XSSy8pOjpaBw4c0OHDhzVy5Ehbbx2AlxBGAJy3mTNn6qKLLlJGRob27t2rkJAQXX755Zo6dar7NMlzzz2nSZMm6euvv1ZMTIz+7//+T/7+/pKkyy+/XEuWLNG0adM0c+ZMRURE6Omnn9Zvf/tb9z5efvllTZ06Vffdd5+OHTumiy++WFOnTrXxdgF4mcMYY2wXAaDp2rBhgwYPHqzjx48rJCTEdjkAGiCuGQEAAFYRRgAAgFWcpgEAAFaxMgIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACw6v8DjUyBEHENK0QAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 600x400 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "<style>\n",
       "    /* background: */\n",
       "    progress::-webkit-progress-bar {background-color: #CDCDCD; width: 100%;}\n",
       "    progress {background-color: #CDCDCD;}\n",
       "\n",
       "    /* value: */\n",
       "    progress::-webkit-progress-value {background-color: #00BFFF  !important;}\n",
       "    progress::-moz-progress-bar {background-color: #00BFFF  !important;}\n",
       "    progress {color: #00BFFF ;}\n",
       "\n",
       "    /* optional */\n",
       "    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
       "        background: #000000;\n",
       "    }\n",
       "</style>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      <progress value='0' class='' max='3' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      0% [0/3]\n",
       "      <br>\n",
       "                          0.21% [58/27189] [train_loss=2.71312,lr=0.00000]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "keras_model.fit(train_data=dl_train,\n",
    "                val_data=dl_test,\n",
    "                epochs=3,\n",
    "                patience=3,  # early stopping: epochs to wait without val_loss improvement\n",
    "                monitor='val_loss',\n",
    "                mode='min',\n",
    "                ckpt_path=ckpt_path,\n",
    "                gradient_accumulation_steps=4\n",
    "                # mixed_precision='fp16'  # optional: enable fp16 mixed precision\n",
    "               )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
